

# Start from an empty workspace and clear the console.
# `all.names = TRUE` is spelled out (the original `all=TRUE` relied on partial
# argument matching) so hidden dot-objects are removed as well.
# NOTE(review): wiping the global environment from inside a script deletes the
# caller's objects; running the script in a fresh R session is safer.
rm(list = ls(all.names = TRUE))
cat("\f")

library(rootSolve)
library(nloptr)
library(entropy)
library(LaplacesDemon) # KLD function 

source("generate_data.R")
source("process_data.R")
source("compute_nde.R") # { G-formula={Y,M}, IPW={M,A}, Mixed={A,Y}, AIPW={A,M,Y} }
source("compute_mse.R") # includes {estimate_y} for reparam estimation of Y

# source("compute_kl_icml.R")

set.seed(0)

# ################################################
# Data
# ################################################

# Simulation settings: rows generated per trial, and number of trials.
n <- 5000
trials <- 100

# Result accumulators. Each starts out NULL (the value of `c()`) and is
# appended to once per trial at the end of the loop body.
nde_0.vec <- NULL
nde_0hat.vec <- NULL
nde_1.vec <- NULL
nde_2.vec <- NULL
nde_3.vec <- NULL
nde_4.vec <- NULL

mse_0.vec <- NULL
mse_1.vec <- NULL
mse_2.vec <- NULL
mse_3.vec <- NULL
mse_4.vec <- NULL

mle_0.vec <- NULL
mle_1.vec <- NULL
mle_2.vec <- NULL
mle_3.vec <- NULL
mle_4.vec <- NULL

kl_0.vec <- NULL
kl_1.vec <- NULL
kl_2.vec <- NULL
kl_3.vec <- NULL
kl_4.vec <- NULL

for (t in seq_len(trials)) {

cat("\n \n trial = ", t, "\n \n")

# Generate one synthetic data set together with its true parameters.
data_obj <- generate_data(n)
dat <- data_obj$data
beta_y0 <- data_obj$beta_y
beta_m0 <- data_obj$beta_m
beta_a0 <- data_obj$beta_a

# True coefficients in the list layout that compute_effect() receives below.
beta_0 <- list(beta_y = beta_y0,
               beta_m = beta_m0,
               beta_a = beta_a0)

# Reference quantities supplied by the generator; mle_0 is the baseline
# that every kl_* value below is measured against.
prob_truth <- data_obj$prob
mle_0 <- data_obj$mle

# Put the columns of `dat` on the search path. Detach the previous trial's
# copy first: otherwise every trial pushes one more "dat" entry onto
# search() and never removes it.
while ("dat" %in% search()) {
  detach("dat")
}
attach(dat, warn.conflicts = FALSE)
# NOTE(review): attach() snapshots `dat` before the test responses are masked
# below, so a bare `Y` resolved via the search path still contains the
# held-out values. Confirm the sourced helpers only read dat$Y.

# Hold out the last 20% of rows as test data: keep their responses in Y_test
# and mask them as NA in `dat`.
n1 <- floor(n * 0.8)  # size of the training portion (first 80% of rows)
stopifnot(n1 < n)     # (n1 + 1):n would count *down* if the test set were empty
idx_test <- (n1 + 1):n
Y_test <- dat$Y[idx_test]

dat$Y[idx_test] <- NA

# Model formulas for the outcome (Y), mediator (M) and treatment (A) models.
fmla_y <- as.formula(Y ~ C + A + AC + M + CM + AM + ACM)
fmla_m <- as.formula(M ~ C + A + AC)
fmla_a <- as.formula(A ~ C)
fmla <- list(fmla_y = fmla_y, fmla_m = fmla_m, fmla_a = fmla_a)


# ################################################
# (0) Unconstrained MLE
# ################################################

source("0_unconstrainedMLE.R")

# Settings shared by the M0 fit and by the true-effect computation:
# ordinary (non-reparameterized) G-formula with uniform weight 1/n per row.
opt_0 <- list(reparam = FALSE, estimator = "G-formula")
px_0 <- rep(1 / n, n)

# Fit the unconstrained model (the test rows are identified by idx_test).
model_0hat <- optimize_unconst(dat, idx_test, fmla, px_0)

# Fitted coefficients, in the list layout compute_effect() expects.
beta_0hat <- list(
  beta_a = model_0hat$beta_a,
  beta_m = model_0hat$beta_m,
  beta_y = model_0hat$beta_y
)
mle_0hat <- model_0hat$mle
nde_0hat <- compute_effect(dat, beta_0hat, px_0, opt_0)
mse_0 <- mean((Y_test - model_0hat$Yhat[idx_test])^2)   # held-out squared error
kl_0 <- (1/n) * (mle_0 - mle_0hat) #+ (1/n)*(sum(log(prob_truth$p_c)) - sum(log(px_0)))

# Effect under the true coefficients, with the same settings as M0.
nde_0 <- compute_effect(dat, beta_0, px_0, opt_0)

# ################################################
# (1) Constrained MLE
# ################################################

source("1_constrainedMLE.R")

# Constrained G-formula fit. tau_l / tau_u are presumably the bounds that
# optimize_nloptr() places on `func` (= compute_effect); confirm in
# 1_constrainedMLE.R.
opt_1 <- list(reparam = FALSE,
              estimator = "G-formula",
              tau_l = -0.05,
              tau_u = 0.05)
              # tau_l = -0.00001,
              # tau_u = 0.00001)
px_1 <- rep(1 / n, n)
func <- compute_effect
# beta_start_1 = NULL
# Warm start from the unconstrained M0 coefficients (mediator, then outcome).
beta_start_1 <- c(beta_0hat$beta_m, beta_0hat$beta_y)

# Fitting
model_1 <- optimize_nloptr(dat, idx_test, func, beta_start_1, px_1, fmla, opt_1)

# Evaluations
beta_1 <- list(beta_a = model_1$beta_a, beta_m = model_1$beta_m, beta_y = model_1$beta_y)
mle_1 <- model_1$mle
nde_1 <- compute_effect(dat, beta_1, px_1, opt_1)
# Score with the weights this model was fitted with (px_1), as the M2 section
# does with px_2. The previous argument `weights_1` was never defined.
mse_1 <- compute_mse(dat, idx_test, Y_test, beta_1, px_1, opt_1)
kl_1 <- (1/n) * (mle_0 - mle_1) #+ (1/n)*(sum(log(prob_truth$p_c)) - sum(log(px_1)))



# ################################################
# (2) Reparameterized MLE
# ################################################

# Treatment-model term p(A | C): a logistic model with formula fmla_a,
# evaluated at the M0 treatment coefficients. p_A is reused by M2, M3 and M4.
Xa <- as.matrix(model.matrix(fmla_a, data = model.frame(dat, na.action = NULL)))
p_A1 <- 1 / (1 + exp(-Xa %*% beta_0hat$beta_a))  # P(A = 1 | C)
# Read A from `dat` explicitly. A bare `A` would be found through attach(),
# and any global object named `A` would silently take precedence.
p_A <- dat$A * p_A1 + (1 - dat$A) * (1 - p_A1)


source("2_reparamMLE.R")

# Initialization
opt_2 <- list(reparam = TRUE,
              estimator = "G-formula")
px_2 <- rep(1 / n, n)
beta_start_2 <- NULL
fmla_f <- as.formula(Y ~ -1 + C + AC + M + CM + AM + ACM) # no "intercept" and no "A"

# Fitting
model_2 <- optimize_reparam(dat, idx_test, fmla_f, fmla_m, px_2, beta_start_2)

# Evaluations: treatment coefficients are held at the M0 estimate.
beta_2 <- list(beta_a = beta_0hat$beta_a, beta_m = model_2$beta_m, beta_y = model_2$beta_y)
mle_YM_2 <- model_2$mle
# Add the log px_2 term (presumably the distribution of C) and log p(A | C).
mle_2 <- mle_YM_2 + sum(log(px_2) + log(p_A))

nde_2 <- compute_effect(dat, beta_2, px_2, opt_2)
mse_2 <- compute_mse(dat, idx_test, Y_test, beta_2, px_2, opt_2)
kl_2 <- (1/n) * (mle_0 - mle_2) #+ (1/n)*(sum(log(prob_truth$p_c)) - sum(log(px_2)))


# ################################################
# (3) hybrid likelihood - Batch Prediction
# ################################################

source("3_hybridMLE.R")

# Optimizer settings for the hybrid fit.
opt_3 <- list(
  reparam = FALSE,
  estimator = "G-formula",
  alpha = 0.001,      # initial step size
  threshold = 0.0005, # stopping criterion
  max_iter = 10000,   # maximum number of iterations
  delta = 0.001       # step for numerical differentiation
)

# Warm start from the M0 coefficients (mediator first, then outcome).
# Other starts tried: rep(0.1, length(beta_m0) + length(beta_y0))
#                     c(beta_1$beta_m, beta_1$beta_y)
beta_start_3 <- c(beta_0hat$beta_m, beta_0hat$beta_y)

model_3 <- optimize_hybrid(beta_start_3, dat, idx_test, fmla_m, fmla_y, opt_3)
px_3 <- model_3$px

# Evaluation: treatment coefficients stay at the M0 estimate.
beta_3 <- list(beta_m = model_3$beta_m, beta_y = model_3$beta_y, beta_a = beta_0hat$beta_a)
mle_YMX_3 <- model_3$mle
# Only log p(A | C) is added here. Judging by its name, model_3$mle already
# includes the X/C term (confirm in 3_hybridMLE.R).
mle_3 <- mle_YMX_3 + sum(log(p_A))

nde_3 <- compute_effect(dat, beta_3, px_3, opt_3)
nde_3_b <- model_3$nde
# NOTE(review): this reads `Y_hat`, whereas M0/M4 read `Yhat`. Confirm that
# optimize_hybrid() returns an element named `Y_hat`; otherwise mse_3 is NaN.
mse_3 <- mean((Y_test - model_3$Y_hat[idx_test])^2)
kl_3 <- (1/n) * (mle_0 - mle_3) #+ (1/n)*(sum(log(prob_truth$p_c)) - sum(log(px_3)))


# ################################################
# (4) Unified: profiling pi
# ################################################

# source("3_hybridMLE.R") # get_pi(), get_lambda()
source("4_unified.R")

# Optimizer settings for the unified fit.
opt_4 <- list(
  reparam = TRUE,
  estimator = "G-formula",
  alpha = 0.001,    # initial step size
  threshold = 1e-4, # stopping criterion
  max_iter = 100,   # maximum number of iterations
  delta = 0.0001    # step for numerical differentiation
)

# Starting values. beta_start_4 = NULL is passed through to optimize_unified()
# (presumably it then picks its own start). px starts from the M3 estimate.
# Other beta starts tried:
#   list(beta_m = beta_2$beta_m, beta_y = beta_2$beta_y), with beta_y[8] = 0.1
#   list(beta_m = beta_0hat$beta_m, beta_y = beta_0hat$beta_y)
beta_start_4 <- NULL
# Other px starts tried:
#   c(1, rep(0, n - 1))
#   c(rep(1/(2*n), n/2), 3/4, rep(0, n - n/2 + 1))
px_start_4 <- px_3

model_4 <- optimize_unified(dat, idx_test, fmla_f, fmla_m, px_start_4, beta_start_4, opt_4)
px_4 <- model_4$px

# Evaluation: treatment coefficients stay at the M0 estimate.
beta_4 <- list(beta_m = model_4$beta_m, beta_y = model_4$beta_y, beta_a = beta_0hat$beta_a)
mle_YM_4 <- model_4$mle
mle_4 <- mle_YM_4 + sum(log(px_4) + log(p_A))

nde_4 <- compute_effect(dat, beta_4, px_4, opt_4)
nde_4_b <- model_4$nde

mse_4 <- mean((Y_test - model_4$Yhat[idx_test])^2)
kl_4 <- (1/n) * (mle_0 - mle_4) #+ (1/n)*(-n*log(n) - sum(log(px_4)))


# ################################################
# Gather the results
# ################################################

# Per-trial console report.
# mle_0 is the log-likelihood returned by generate_data() (the baseline for
# every KL line); mle_0hat is the fitted unconstrained (M0) value. The M0 MLE
# line previously printed mle_0 under the M0 label.
cat("\n",
    " ++++++++ NDE using G-formula +++++++++ \n \n",
    "Unfair NDE = ", nde_0, "\n",
    "(M0) Unfair NDE - Unconstrained = ", nde_0hat, "\n",
    "(M1) Fair NDE - Constrained = ", nde_1, "\n",
    "(M2) Fair NDE - Reparam = ", nde_2, "\n",
    "(M3) Fair NDE - Hybrid = ", nde_3, "\n",
    "(M4) Fair NDE - Unified = ", nde_4, "\n \n",
    #
    " ++++++++ Mean Squared Error +++++++++ \n \n",
    "(M0) MSE - Unconstrained = ", mse_0, "\n",
    "(M1) MSE - Constrained = ", mse_1, "\n",
    "(M2) MSE - Reparam = ", mse_2, "\n",
    "(M3) MSE - Hybrid = ", mse_3, "\n",
    "(M4) MSE - Unified = ", mse_4, "\n \n",
    #
    " ++++++++ Log MLE +++++++++ \n \n",
    "Reference MLE (data generator) = ", mle_0, "\n",
    "(M0) MLE - Unconstrained = ", mle_0hat, "\n",
    "(M1) MLE - Constrained = ", mle_1, "\n",
    "(M2) MLE - Reparam = ", mle_2, "\n",
    "(M3) MLE - Hybrid = ", mle_3, "\n",
    "(M4) MLE - Unified = ", mle_4, "\n \n",
    #
    " ++++++++ KL +++++++++ \n \n",
    "(M0) KL - Unconstrained = ", kl_0, "\n",
    "(M1) KL - Constrained = ", kl_1, "\n",
    "(M2) KL - Reparam = ", kl_2, "\n",
    "(M3) KL - Hybrid = ", kl_3, "\n",
    "(M4) KL - Unified = ", kl_4, "\n \n")


# Record this trial's values at the end of each accumulator.
nde_0.vec    <- append(nde_0.vec, nde_0)
nde_0hat.vec <- append(nde_0hat.vec, nde_0hat)
nde_1.vec    <- append(nde_1.vec, nde_1)
nde_2.vec    <- append(nde_2.vec, nde_2)
nde_3.vec    <- append(nde_3.vec, nde_3)
nde_4.vec    <- append(nde_4.vec, nde_4)

mse_0.vec <- append(mse_0.vec, mse_0)
mse_1.vec <- append(mse_1.vec, mse_1)
mse_2.vec <- append(mse_2.vec, mse_2)
mse_3.vec <- append(mse_3.vec, mse_3)
mse_4.vec <- append(mse_4.vec, mse_4)

mle_0.vec <- append(mle_0.vec, mle_0)
mle_1.vec <- append(mle_1.vec, mle_1)
mle_2.vec <- append(mle_2.vec, mle_2)
mle_3.vec <- append(mle_3.vec, mle_3)
mle_4.vec <- append(mle_4.vec, mle_4)

kl_0.vec <- append(kl_0.vec, kl_0)
kl_1.vec <- append(kl_1.vec, kl_1)
kl_2.vec <- append(kl_2.vec, kl_2)
kl_3.vec <- append(kl_3.vec, kl_3)
kl_4.vec <- append(kl_4.vec, kl_4)

} # end of the per-trial loop

# One row per trial, rounded to 6 decimals, written to results.csv.
# Column suffixes follow the model numbering M0-M4: nde_0 holds the M0
# (unconstrained) estimate, and mle_0 is the reference log-likelihood from
# generate_data(). The true NDE is appended last as `nde_true`; previously it
# was computed every trial but never saved. Existing columns are unchanged.
results <- round(data.frame(nde_0 = nde_0hat.vec,
                            nde_1 = nde_1.vec,
                            nde_2 = nde_2.vec,
                            nde_3 = nde_3.vec,
                            nde_4 = nde_4.vec,
                            mse_0 = mse_0.vec,
                            mse_1 = mse_1.vec,
                            mse_2 = mse_2.vec,
                            mse_3 = mse_3.vec,
                            mse_4 = mse_4.vec,
                            mle_0 = mle_0.vec,
                            mle_1 = mle_1.vec,
                            mle_2 = mle_2.vec,
                            mle_3 = mle_3.vec,
                            mle_4 = mle_4.vec,
                            kl_0 = kl_0.vec,
                            kl_1 = kl_1.vec,
                            kl_2 = kl_2.vec,
                            kl_3 = kl_3.vec,
                            kl_4 = kl_4.vec,
                            nde_true = nde_0.vec), 6)

write.csv(results, "results.csv", row.names = FALSE)


# ################################################
# Comparisons
# ################################################

# Print every trial's value for each metric.
# mle_0.vec holds the reference log-likelihood from generate_data(), not the
# fitted M0 value, so that line is labelled accordingly.
cat("\n",
    " ++++++++ NDE using G-formula +++++++++ \n \n",
    "Unfair NDE = ", nde_0.vec, "\n",
    "(M0) Unfair NDE - Unconstrained = ", nde_0hat.vec, "\n",
    "(M1) Fair NDE - Constrained = ", nde_1.vec, "\n",
    "(M2) Fair NDE - Reparam = ", nde_2.vec, "\n",
    "(M3) Fair NDE - Hybrid = ", nde_3.vec, "\n",
    "(M4) Fair NDE - Unified = ", nde_4.vec, "\n \n",
    #
    " ++++++++ Mean Squared Error +++++++++ \n \n",
    "(M0) MSE - Unconstrained = ", mse_0.vec, "\n",
    "(M1) MSE - Constrained = ", mse_1.vec, "\n",
    "(M2) MSE - Reparam = ", mse_2.vec, "\n",
    "(M3) MSE - Hybrid = ", mse_3.vec, "\n",
    "(M4) MSE - Unified = ", mse_4.vec, "\n \n",
    #
    " ++++++++ Log MLE +++++++++ \n \n",
    "Reference MLE (data generator) = ", mle_0.vec, "\n",
    "(M1) MLE - Constrained = ", mle_1.vec, "\n",
    "(M2) MLE - Reparam = ", mle_2.vec, "\n",
    "(M3) MLE - Hybrid = ", mle_3.vec, "\n",
    "(M4) MLE - Unified = ", mle_4.vec, "\n \n",
    #
    " ++++++++ KL +++++++++ \n \n",
    "(M0) KL - Unconstrained = ", kl_0.vec, "\n",
    "(M1) KL - Constrained = ", kl_1.vec, "\n",
    "(M2) KL - Reparam = ", kl_2.vec, "\n",
    "(M3) KL - Hybrid = ", kl_3.vec, "\n",
    "(M4) KL - Unified = ", kl_4.vec, "\n \n")


# Print the average of each metric across all trials.
# mle_0.vec holds the reference log-likelihood from generate_data(), not the
# fitted M0 value, so that line is labelled accordingly.
cat("\n",
    " ++++++++ NDE using G-formula +++++++++ \n \n",
    "Unfair NDE = ", mean(nde_0.vec), "\n",
    "(M0) Unfair NDE - Unconstrained = ", mean(nde_0hat.vec), "\n",
    "(M1) Fair NDE - Constrained = ", mean(nde_1.vec), "\n",
    "(M2) Fair NDE - Reparam = ", mean(nde_2.vec), "\n",
    "(M3) Fair NDE - Hybrid = ", mean(nde_3.vec), "\n",
    "(M4) Fair NDE - Unified = ", mean(nde_4.vec), "\n \n",
    #
    " ++++++++ Mean Squared Error +++++++++ \n \n",
    "(M0) MSE - Unconstrained = ", mean(mse_0.vec), "\n",
    "(M1) MSE - Constrained = ", mean(mse_1.vec), "\n",
    "(M2) MSE - Reparam = ", mean(mse_2.vec), "\n",
    "(M3) MSE - Hybrid = ", mean(mse_3.vec), "\n",
    "(M4) MSE - Unified = ", mean(mse_4.vec), "\n \n",
    #
    " ++++++++ Log MLE +++++++++ \n \n",
    "Reference MLE (data generator) = ", mean(mle_0.vec), "\n",
    "(M1) MLE - Constrained = ", mean(mle_1.vec), "\n",
    "(M2) MLE - Reparam = ", mean(mle_2.vec), "\n",
    "(M3) MLE - Hybrid = ", mean(mle_3.vec), "\n",
    "(M4) MLE - Unified = ", mean(mle_4.vec), "\n \n",
    #
    " ++++++++ KL +++++++++ \n \n",
    "(M0) KL - Unconstrained = ", mean(kl_0.vec), "\n",
    "(M1) KL - Constrained = ", mean(kl_1.vec), "\n",
    "(M2) KL - Reparam = ", mean(kl_2.vec), "\n",
    "(M3) KL - Hybrid = ", mean(kl_3.vec), "\n",
    "(M4) KL - Unified = ", mean(kl_4.vec), "\n \n")





